{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 序列逆置\n",
    "使用sequence to sequence 模型将一个字符串序列逆置。\n",
    "例如 `OIMESIQFIQ` 逆置成 `QIFQISEMIO`(下图来自网络，是一个sequence to sequence 模型示意图 )\n",
    "![seq2seq](./seq2seq.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import collections\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras import layers, optimizers, datasets\n",
    "import os,sys,tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 玩具序列数据生成\n",
    "生成只包含[A-Z]的字符串，并且将encoder输入以及decoder输入以及decoder输出准备好（转成index）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(['IREBUBLOTS', 'IWEPUFQWJD'], <tf.Tensor: id=0, shape=(2, 10), dtype=int32, numpy=\n",
      "array([[ 9, 18,  5,  2, 21,  2, 12, 15, 20, 19],\n",
      "       [ 9, 23,  5, 16, 21,  6, 17, 23, 10,  4]], dtype=int32)>, <tf.Tensor: id=1, shape=(2, 10), dtype=int32, numpy=\n",
      "array([[ 0, 19, 20, 15, 12,  2, 21,  2,  5, 18],\n",
      "       [ 0,  4, 10, 23, 17,  6, 21, 16,  5, 23]], dtype=int32)>, <tf.Tensor: id=2, shape=(2, 10), dtype=int32, numpy=\n",
      "array([[19, 20, 15, 12,  2, 21,  2,  5, 18,  9],\n",
      "       [ 4, 10, 23, 17,  6, 21, 16,  5, 23,  9]], dtype=int32)>)\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "import string\n",
    "\n",
    "def randomString(stringLength):\n",
    "    \"\"\"Generate a random string with the combination of lowercase and uppercase letters \"\"\"\n",
    "\n",
    "    letters = string.ascii_uppercase\n",
    "    return ''.join(random.choice(letters) for i in range(stringLength))\n",
    "\n",
    "def get_batch(batch_size, length):\n",
    "    '''Sample random uppercase strings together with their reversal targets.\n",
    "\n",
    "    Returns a 4-tuple `(batched_examples, enc_x, dec_x, y)`:\n",
    "      * batched_examples: list of `batch_size` raw strings,\n",
    "      * enc_x: int32 tensor of encoder ids ('A' -> 1 ... 'Z' -> 26),\n",
    "      * dec_x: int32 tensor of decoder inputs, i.e. `y` shifted right\n",
    "        behind a leading 0 (the start token),\n",
    "      * y: int32 tensor of target ids, each row of `enc_x` reversed.\n",
    "    '''\n",
    "    base = ord('A') - 1  # subtracting this maps 'A' to 1 and keeps 0 free\n",
    "    batched_examples, enc_x, dec_x, y = [], [], [], []\n",
    "    for _ in range(batch_size):\n",
    "        text = randomString(length)\n",
    "        ids = [ord(ch) - base for ch in text]\n",
    "        target = ids[::-1]\n",
    "        batched_examples.append(text)\n",
    "        enc_x.append(ids)\n",
    "        y.append(target)\n",
    "        # Teacher-forcing input: start token followed by all but the last target.\n",
    "        dec_x.append([0] + target[:-1])\n",
    "    return (batched_examples,\n",
    "            tf.constant(enc_x, dtype=tf.int32),\n",
    "            tf.constant(dec_x, dtype=tf.int32),\n",
    "            tf.constant(y, dtype=tf.int32))\n",
    "# Quick look at one tiny batch: (strings, enc_x, dec_x, y).\n",
    "print(get_batch(batch_size=2, length=10))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 建立sequence to sequence 模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class mySeq2SeqModel(keras.Model):\n",
    "    '''Character-level encoder-decoder built from two SimpleRNN layers.\n",
    "\n",
    "    Token ids are expected in [0, v_sz). `get_batch` maps 'A'-'Z' to 1..26\n",
    "    and uses 0 as the decoder start token, hence the vocabulary size of 27.\n",
    "    Encoder and decoder share a single embedding table.\n",
    "    '''\n",
    "\n",
    "    def __init__(self):\n",
    "        super(mySeq2SeqModel, self).__init__()\n",
    "        self.v_sz = 27  # 26 letters + the start token 0\n",
    "        self.embed_layer = tf.keras.layers.Embedding(self.v_sz, 64,\n",
    "                                                    batch_input_shape=[None, None])\n",
    "\n",
    "        # The cells are kept as attributes so that inference can step the\n",
    "        # decoder one token at a time (see get_next_token).\n",
    "        self.encoder_cell = tf.keras.layers.SimpleRNNCell(128)\n",
    "        self.decoder_cell = tf.keras.layers.SimpleRNNCell(128)\n",
    "\n",
    "        self.encoder = tf.keras.layers.RNN(self.encoder_cell,\n",
    "                                           return_sequences=True, return_state=True)\n",
    "        self.decoder = tf.keras.layers.RNN(self.decoder_cell,\n",
    "                                           return_sequences=True, return_state=True)\n",
    "        self.dense = tf.keras.layers.Dense(self.v_sz)\n",
    "\n",
    "    @tf.function\n",
    "    def call(self, enc_ids, dec_ids):\n",
    "        '''Compute teacher-forced decoder logits.\n",
    "\n",
    "        Args:\n",
    "            enc_ids: int tensor, shape(b_sz, enc_len), source token ids.\n",
    "            dec_ids: int tensor, shape(b_sz, dec_len), decoder input ids\n",
    "                (the target shifted right behind a leading 0).\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape(b_sz, dec_len, v_sz) with unnormalized scores.\n",
    "        '''\n",
    "        enc_emb = self.embed_layer(enc_ids)  # shape(b_sz, enc_len, emb_sz)\n",
    "        dec_emb = self.embed_layer(dec_ids)  # shape(b_sz, dec_len, emb_sz)\n",
    "        # Only the final encoder state is needed here: it seeds the decoder.\n",
    "        _, enc_state = self.encoder(enc_emb)\n",
    "        dec_out, _ = self.decoder(dec_emb, initial_state=enc_state)\n",
    "        logits = self.dense(dec_out)  # shape(b_sz, dec_len, v_sz)\n",
    "        return logits\n",
    "\n",
    "    def encode(self, enc_ids):\n",
    "        '''Run the encoder over `enc_ids` and return the decoder start state.\n",
    "\n",
    "        Returns a two-element list: the encoder output at the last time step\n",
    "        and the final encoder state, in that order.\n",
    "        '''\n",
    "        enc_emb = self.embed_layer(enc_ids)  # shape(b_sz, len, emb_sz)\n",
    "        enc_out, enc_state = self.encoder(enc_emb)\n",
    "\n",
    "        return [enc_out[:, -1, :], enc_state]\n",
    "\n",
    "    def get_next_token(self, x, state):\n",
    "        '''Advance the decoder by one step and pick the arg-max token.\n",
    "\n",
    "        Args:\n",
    "            x: int tensor, shape(b_sz,), the previously emitted token ids.\n",
    "            state: decoder cell state, e.g. the list returned by `encode`.\n",
    "\n",
    "        Returns:\n",
    "            (out, state): next token ids of shape(b_sz,) and the new state.\n",
    "        '''\n",
    "        inp_emb = self.embed_layer(x)  # shape(b_sz, emb_sz)\n",
    "        # NOTE(review): `.call` is invoked directly, which presumably requires\n",
    "        # the cell to be built already (e.g. by a prior forward pass) - confirm.\n",
    "        h, state = self.decoder_cell.call(inp_emb, state)  # shape(b_sz, h_sz)\n",
    "        logits = self.dense(h)  # shape(b_sz, v_sz)\n",
    "        out = tf.argmax(logits, axis=-1)\n",
    "        return out, state"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loss函数以及训练逻辑"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def compute_loss(logits, labels):\n",
    "    '''Return the mean sparse softmax cross-entropy over all positions.\n",
    "\n",
    "    `labels` holds integer class ids; `logits` holds unnormalized scores\n",
    "    with one extra trailing axis over the classes.\n",
    "    '''\n",
    "    per_position = tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "        labels=labels, logits=logits)\n",
    "    return tf.reduce_mean(per_position)\n",
    "\n",
    "@tf.function\n",
    "def train_one_step(model, optimizer, enc_x, dec_x, y):\n",
    "    '''Run one optimization step on a single batch and return its loss.\n",
    "\n",
    "    The forward pass and the loss are recorded on a gradient tape; the\n",
    "    resulting gradients are applied to the model's trainable variables.\n",
    "    '''\n",
    "    with tf.GradientTape() as tape:\n",
    "        loss = compute_loss(model(enc_x, dec_x), y)\n",
    "\n",
    "    variables = model.trainable_variables\n",
    "    gradients = tape.gradient(loss, variables)\n",
    "    optimizer.apply_gradients(zip(gradients, variables))\n",
    "    return loss\n",
    "\n",
    "def train(model, optimizer, seqlen, steps=3000, batch_size=32, log_every=500):\n",
    "    '''Train `model` on freshly sampled reversal batches.\n",
    "\n",
    "    Args:\n",
    "        model: model called as `model(enc_x, dec_x)` to produce logits.\n",
    "        optimizer: optimizer handed to `train_one_step`.\n",
    "        seqlen: length of the random strings in every batch.\n",
    "        steps: number of optimization steps (default 3000).\n",
    "        batch_size: number of strings per batch (default 32).\n",
    "        log_every: print the loss every this many steps (default 500);\n",
    "            0 or None disables logging.\n",
    "\n",
    "    Returns:\n",
    "        The loss of the last step, or 0.0 when `steps` is not positive.\n",
    "    '''\n",
    "    loss = 0.0\n",
    "    for step in range(steps):\n",
    "        # A new random batch is drawn for every step.\n",
    "        _, enc_x, dec_x, y = get_batch(batch_size, seqlen)\n",
    "        loss = train_one_step(model, optimizer, enc_x, dec_x, y)\n",
    "        if log_every and step % log_every == 0:\n",
    "            print('step', step, ': loss', loss.numpy())\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 训练迭代"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 0 : loss 3.3078957\n",
      "step 500 : loss 1.6194503\n",
      "step 1000 : loss 1.1144965\n",
      "step 1500 : loss 0.8585207\n",
      "step 2000 : loss 0.69941413\n",
      "step 2500 : loss 0.6125215\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=14328, shape=(), dtype=float32, numpy=0.4644309>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the optimizer and a new model, then train on length-20 strings.\n",
    "optimizer = optimizers.Adam(learning_rate=0.0005)\n",
    "model = mySeq2SeqModel()\n",
    "# Last expression of the cell, so the returned loss is shown as its output.\n",
    "train(model=model, optimizer=optimizer, seqlen=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 测试模型逆置能力\n",
    "首先对输入的字符串进行encode，然后再用decoder解码出逆置的字符串\n",
    "\n",
    "测试阶段跟训练阶段的区别在于，在训练的时候decoder的输入是给定的，而在预测的时候我们需要一步步生成下一步的decoder的输入"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True]\n",
      "[('OIMESIQFIQ', 'QIFQISEMIO'), ('MCFHOXNWHS', 'SHWNXOHFCM'), ('BYGBKVOAQT', 'TQAOVKBGYB'), ('RHSOVKGUWI', 'IWUGKVOSHR'), ('LKHWBGZXBO', 'OBXZGBWHKL'), ('ERYQCZPWKB', 'BKWPZCQYRE'), ('ZIQQHPBJTR', 'RTJBPHQQIZ'), ('AQFDDTPZQH', 'HQZPTDDFQA'), ('PRWJBEGJEK', 'QMJQEEJWRP'), ('MRAWVSFNZL', 'LZNFSVWARM'), ('CLFVWJRCJV', 'HQBIJWZFLC'), ('WARGSOPMKZ', 'ZKMPOSGRAW'), ('OLWBOACXQY', 'YQXCAOBWLO'), ('VBKDBZCRIP', 'PIRCZBDKBV'), ('BRWCJPWTWQ', 'QWTWPJCWRB'), ('ZARPMFTSES', 'SESTFMPRAZ'), ('VXGVVJFUBR', 'RBUFJVVGXV'), ('ETAQRUYCCQ', 'QCCYURQATE'), ('BIOVWZODCI', 'ICDOZWVOIB'), ('RLOAZAKWEP', 'PEWKAZAOLR'), ('YRMRDECTVZ', 'ZVTCEDRMRY'), ('WMPHBAHFGJ', 'JGFHABHPMW'), ('LEAHJTFHEG', 'GEHFTJHAEL'), ('WIHALQECAX', 'XACEQLAHIW'), ('XYDOZHQOTE', 'ETOQHZODYX'), ('XEVMHNNBRS', 'SRBNNHMVEX'), ('JCKCQRHZYP', 'PYZHRQCKCJ'), ('FZZROYFMUR', 'RUMFYORZZF'), ('IMDNMAIJEH', 'HEJIAMNDMI'), ('LHNWNFQNQH', 'ZQNQJNONHL'), ('MPLVUUCUUG', 'GUUCUUVLPM'), ('AAFWQGSHWO', 'OWHSGQWFAA')]\n"
     ]
    }
   ],
   "source": [
    "def sequence_reversal(batch_size=32, seqlen=10):\n",
    "    '''Greedy-decode one random batch with the global `model`.\n",
    "\n",
    "    Args:\n",
    "        batch_size: number of random strings to sample (default 32).\n",
    "        seqlen: length of each sampled string (default 10).\n",
    "\n",
    "    Returns:\n",
    "        (decoded, batched_examples): two parallel lists of strings, the\n",
    "        model outputs first and the sampled input strings second.\n",
    "    '''\n",
    "    def decode(init_state, steps=10):\n",
    "        '''Emit `steps` tokens, feeding each prediction back as next input.'''\n",
    "        b_sz = tf.shape(init_state[0])[0]\n",
    "        cur_token = tf.zeros(shape=[b_sz], dtype=tf.int32)  # start token 0\n",
    "        state = init_state\n",
    "        collect = []\n",
    "        for _ in range(steps):\n",
    "            cur_token, state = model.get_next_token(cur_token, state)\n",
    "            collect.append(tf.expand_dims(cur_token, axis=-1))\n",
    "        out = tf.concat(collect, axis=-1).numpy()  # shape(b_sz, steps)\n",
    "        # Invert the id mapping used by get_batch (1 -> 'A').\n",
    "        return [''.join(chr(idx + ord('A') - 1) for idx in exp) for exp in out]\n",
    "\n",
    "    batched_examples, enc_x, _, _ = get_batch(batch_size, seqlen)\n",
    "    state = model.encode(enc_x)\n",
    "    # Decode exactly as many tokens as the input has characters.\n",
    "    return decode(state, enc_x.get_shape()[-1]), batched_examples\n",
    "\n",
    "def is_reverse(seq, rev_seq):\n",
    "    rev_seq_rev = ''.join([i for i in reversed(list(rev_seq))])\n",
    "    if seq == rev_seq_rev:\n",
    "        return True\n",
    "    else:\n",
    "        return False\n",
    "# Decode a single batch and reuse it for both printouts, so the True/False\n",
    "# verdicts refer to the same (decoded, original) pairs that are listed.\n",
    "reversal_pairs = list(zip(*sequence_reversal()))\n",
    "print([is_reverse(*item) for item in reversal_pairs])\n",
    "print(reversal_pairs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
